# -*- coding: utf-8 -*-
"""WGAN.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_q3QKVeCbeMqHJdMJcjkvC0PLoyP17jB
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Optimisation settings.
lr = 1e-4
batch_size = 64
num_epochs = 5
critic_iterations = 5  # critic updates performed per generator update
lambda_gp = 10  # weight of the gradient penalty in the critic loss

# Data and model dimensions.
image_size = 64
channels_img = 1
z_dim = 100  # number of channels of the generator's noise input
features_d = 64  # base channel width of the critic
features_g = 64  # base channel width of the generator

# Generator and critic network definitions
class Generator(nn.Module):
    """Transposed-convolution generator.

    Maps a noise batch of shape ``N x z_dim x 1 x 1`` to an image batch of
    shape ``N x channels_img x 64 x 64``. The final ``Tanh`` keeps every
    output value inside ``[-1, 1]``.

    Args:
        z_dim: Number of channels of the noise input.
        channels_img: Number of channels of the generated image.
        features_g: Base channel width; hidden stages use 8x, 4x, 2x and 1x
            this value.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()
        # Channel widths of the four hidden stages, widest first.
        widths = [features_g * mult for mult in (8, 4, 2, 1)]

        # Projection stage: 1x1 -> 4x4.
        layers = [
            nn.ConvTranspose2d(z_dim, widths[0], kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
        ]
        # Up-sampling stages, each doubling the resolution: 8x8, 16x16, 32x32.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        # Output stage: 64x64 image squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(
                widths[-1], channels_img, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the noise batch ``x`` through the layer stack."""
        return self.net(x)

class Critic(nn.Module):
    """Convolutional critic.

    Maps an image batch of shape ``N x channels_img x 64 x 64`` to a score
    tensor of shape ``N x 1 x 1 x 1``. No activation follows the last
    convolution, so the scores are unbounded.

    Args:
        channels_img: Number of channels of the input image.
        features_d: Base channel width; hidden stages use 1x, 2x, 4x and 8x
            this value.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        # Channel widths of the four down-sampling stages, narrowest first.
        widths = [features_d * mult for mult in (1, 2, 4, 8)]

        # First stage has no normalisation: 64x64 -> 32x32.
        layers = [
            nn.Conv2d(channels_img, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Remaining stages halve the resolution each: 16x16, 8x8, 4x4.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers.extend(
                [
                    nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                    nn.InstanceNorm2d(out_ch),
                    nn.LeakyReLU(0.2),
                ]
            )
        # Collapse the 4x4 feature map to a single score per sample.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score the image batch ``x``."""
        return self.net(x)

def initialize_weights(model):
    """Re-draw the weights of conv and batch-norm layers from N(0, 0.02).

    Walks every sub-module of ``model`` and, for each ``nn.Conv2d``,
    ``nn.ConvTranspose2d`` and ``nn.BatchNorm2d``, overwrites ``weight``
    in place with samples from a normal distribution with mean 0.0 and
    standard deviation 0.02. Biases and all other module types are left
    untouched.

    Args:
        model: The ``nn.Module`` whose sub-modules are re-initialised.
    """
    # NOTE(review): BatchNorm scale factors are also centred on 0.0 here;
    # confirm that is intended rather than a mean of 1.0.
    for m in model.modules():
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            continue
        # BatchNorm2d(affine=False) has no learnable scale (weight is None).
        if m.weight is None:
            continue
        # nn.init functions already run without autograd tracking, so the
        # parameter can be passed directly instead of going through `.data`.
        nn.init.normal_(m.weight, mean=0.0, std=0.02)

def gradient_penalty(critic, real, fake, device="cpu"):
    """Return the gradient penalty of ``critic`` between ``real`` and ``fake``.

    Each real sample is blended with the matching fake sample using one
    random weight per sample. The critic is evaluated on the blend, and the
    penalty is the batch mean of ``(||d score / d blend||_2 - 1) ** 2``.

    Both inputs are detached before blending, so the penalty only
    back-propagates into the critic's parameters (never into whatever
    produced ``fake``), and the function also works when neither input
    carries autograd history.

    Args:
        critic: Callable mapping a batch of samples to per-sample scores.
        real: Batch of real samples, shape ``(N, ...)``.
        fake: Batch of generated samples with the same shape as ``real``.
        device: Unused; kept so existing calls that pass it keep working.
            The blend weights are created on ``real``'s device.

    Returns:
        A scalar tensor holding the penalty. It is differentiable with
        respect to the critic's parameters.
    """
    real = real.detach()
    fake = fake.detach()

    # One blend weight per sample, broadcast over all remaining dimensions.
    alpha_shape = (real.shape[0],) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device)

    # The blend is made a leaf that requires grad so it can be differentiated
    # regardless of the inputs' own autograd state.
    interpolated_images = (real * alpha + fake * (1 - alpha)).requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images.
    # create_graph=True keeps the result differentiable so the penalty can be
    # back-propagated into the critic.
    gradient = grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # flatten() (unlike view()) also accepts non-contiguous gradients.
    gradient_norm = gradient.flatten(start_dim=1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

# Build both networks on the target device, then re-draw their weights.
gen = Generator(z_dim=z_dim, channels_img=channels_img, features_g=features_g)
gen = gen.to(device)
critic = Critic(channels_img=channels_img, features_d=features_d)
critic = critic.to(device)
for network in (gen, critic):
    initialize_weights(network)

# One Adam optimizer per network, sharing the same settings.
adam_settings = {"lr": lr, "betas": (0.0, 0.9)}
opt_gen = optim.Adam(gen.parameters(), **adam_settings)
opt_critic = optim.Adam(critic.parameters(), **adam_settings)

# Input pipeline: resize, convert to a tensor, then normalise every channel
# with mean 0.5 and std 0.5.
per_channel_half = [0.5] * channels_img
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(per_channel_half, per_channel_half),
    ]
)

# MNIST is downloaded into "dataset/" when it is not already there.
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# TensorBoard logging: the same 32 noise vectors are reused for every
# snapshot, and real/fake grids go to separate writers.
fixed_noise = torch.randn(32, z_dim, 1, 1).to(device)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # global_step passed to the TensorBoard writers

# Put both networks in training mode.
gen.train()
critic.train()

# Training loop
#
# For every real batch the critic is updated `critic_iterations` times and the
# generator once. Every 100 batches the losses are printed and a grid of real
# and generated images is written to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # Sized from the batch itself because the last batch may be smaller
        # than batch_size.
        current_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(critic_iterations):
            # Allocate the noise directly on the target device.
            noise = torch.randn(current_batch_size, z_dim, 1, 1, device=device)
            # The generator is not updated here, so its autograd graph is
            # not needed for the critic step.
            with torch.no_grad():
                fake = gen(noise)
            # gradient_penalty differentiates the critic w.r.t. a blend of
            # real and fake; flagging this leaf guarantees the blend is
            # tracked by autograd without building the generator's graph.
            fake.requires_grad_(True)

            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake.detach()).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp
            )
            critic.zero_grad()
            # Each iteration builds its own graph, so nothing has to be
            # retained after the backward pass.
            loss_critic.backward()
            opt_critic.step()

        # Train Generator: max E[critic(fake)] <-> min -E[critic(fake)]
        # A fresh batch is generated with autograd enabled so the loss can
        # reach the generator's parameters.
        noise = torch.randn(current_batch_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        output = critic(fake).reshape(-1)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            # Adjacent f-strings are used instead of a backslash continuation
            # so the source indentation does not end up in the message.
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {loss_critic.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:32], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

# Flush any pending events and release the TensorBoard log files.
writer_real.close()
writer_fake.close()

# Notebook leftover: shows the last fake grid when this is the final
# expression of a cell; it has no effect when run as a script.
img_grid_fake

"""To visualize the generated images using TensorBoard, run the following commands in new code cells:"""

# Commented out IPython magic to ensure Python compatibility.
# %load_ext tensorboard
# %tensorboard --logdir logs

"""After running these cells, a TensorBoard interface will appear in your Colab output, where you can see the real and generated images over time."""